{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN 做图像分类\n",
    "前面我们讲了 RNN 特别适合做序列类型的数据，那么 RNN 能不能想 CNN 一样用来做图像分类呢？下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类，但是这种方法并不是主流，这里我们只是作为举例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于一张手写字体的图片，其大小是 28 * 28，我们可以将其看做是一个长为 28 的序列，每个序列的特征都是 28，也就是\n",
    "\n",
    "![](https://ws4.sinaimg.cn/large/006tKfTcly1fmu7d0byfkj30n60djdg5.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这样我们解决了输入序列的问题，对于输出序列怎么办呢？其实非常简单，虽然我们的输出是一个序列，但是我们只需要保留其中一个作为输出结果就可以了，这样的话肯定保留最后一个结果是最好的，因为最后一个结果有前面所有序列的信息，就像下面这样\n",
    "\n",
    "![](https://ws3.sinaimg.cn/large/006tKfTcly1fmu7fpqri0j30c407yjr8.jpg)\n",
    "\n",
    "下面我们直接通过例子展示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-26T08:01:44.502896Z",
     "start_time": "2017-12-26T08:01:44.062542Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision import transforms as tfs\n",
    "from torchvision.datasets import MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-26T08:01:50.714439Z",
     "start_time": "2017-12-26T08:01:50.650872Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义数据\n",
    "data_tf = tfs.Compose([\n",
    "    tfs.ToTensor(),\n",
    "    tfs.Normalize([0.5], [0.5]) # 标准化\n",
    "])\n",
    "\n",
    "train_set = MNIST('./data', train=True, transform=data_tf)\n",
    "test_set = MNIST('./data', train=False, transform=data_tf)\n",
    "\n",
    "train_data = DataLoader(train_set, 64, True, num_workers=4)\n",
    "test_data = DataLoader(test_set, 128, False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-26T08:01:51.165144Z",
     "start_time": "2017-12-26T08:01:51.115807Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "class rnn_classify(nn.Module):\n",
    "    def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):\n",
    "        super(rnn_classify, self).__init__()\n",
    "        self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # 使用两层 lstm\n",
    "        self.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果\n",
    "        \n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        x 大小为 (batch, 1, 28, 28)，所以我们需要将其转换成 RNN 的输入形式，即 (28, batch, 28)\n",
    "        '''\n",
    "        x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1，变成 (batch, 28, 28)\n",
    "        x = x.permute(2, 0, 1) # 将最后一维放到第一维，变成 (28, batch, 28)\n",
    "        out, _ = self.rnn(x) # 使用默认的隐藏状态，得到的 out 是 (28, batch, hidden_feature)\n",
    "        out = out[-1, :, :] # 取序列中的最后一个，大小是 (batch, hidden_feature)\n",
    "        out = self.classifier(out) # 得到分类结果\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-26T08:01:51.252533Z",
     "start_time": "2017-12-26T08:01:51.244612Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = rnn_classify()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-26T08:03:36.739732Z",
     "start_time": "2017-12-26T08:01:51.607967Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 1.858605, Train Acc: 0.318347, Valid Loss: 1.147508, Valid Acc: 0.578125, Time 00:00:09\n",
      "Epoch 1. Train Loss: 0.503072, Train Acc: 0.848514, Valid Loss: 0.300552, Valid Acc: 0.912579, Time 00:00:09\n",
      "Epoch 2. Train Loss: 0.224762, Train Acc: 0.934785, Valid Loss: 0.176321, Valid Acc: 0.946499, Time 00:00:09\n",
      "Epoch 3. Train Loss: 0.157010, Train Acc: 0.953392, Valid Loss: 0.155280, Valid Acc: 0.954015, Time 00:00:09\n",
      "Epoch 4. Train Loss: 0.125926, Train Acc: 0.962137, Valid Loss: 0.105295, Valid Acc: 0.969640, Time 00:00:09\n",
      "Epoch 5. Train Loss: 0.104938, Train Acc: 0.968450, Valid Loss: 0.091477, Valid Acc: 0.972805, Time 00:00:10\n",
      "Epoch 6. Train Loss: 0.089124, Train Acc: 0.973481, Valid Loss: 0.104799, Valid Acc: 0.969343, Time 00:00:09\n",
      "Epoch 7. Train Loss: 0.077920, Train Acc: 0.976679, Valid Loss: 0.084242, Valid Acc: 0.976661, Time 00:00:10\n",
      "Epoch 8. Train Loss: 0.070259, Train Acc: 0.978795, Valid Loss: 0.078536, Valid Acc: 0.977749, Time 00:00:09\n",
      "Epoch 9. Train Loss: 0.063089, Train Acc: 0.981093, Valid Loss: 0.066984, Valid Acc: 0.980716, Time 00:00:09\n"
     ]
    }
   ],
   "source": [
    "# 开始训练\n",
    "from utils import train\n",
    "train(net, train_data, test_data, 10, optimzier, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，训练 10 次在简单的 mnist 数据集上也取得的了 98% 的准确率，所以说 RNN 也可以做做简单的图像分类，但是这并不是他的主战场，下次课我们会讲到 RNN 的一个使用场景，时间序列预测。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
